{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Keyword-Extraction using BERT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use BERT Token Classification Model to extract keyword tokens from a sentence."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Dataset for BERT.\n",
    "\n",
    "Convert Sem-Eval 2010 keyword recognition dataset to BIO format dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from nltk.tokenize import sent_tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = \"maui-semeval2010-train\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "txt = sorted([f for f in os.listdir(train_path) if not f.endswith(\"-justTitle.txt\") and not f.endswith(\".key\") and not f.endswith(\"-CrowdCountskey\")])\n",
    "key = sorted([f for f in os.listdir(train_path) if f.endswith(\".key\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "filekey = dict()\n",
    "for i, k in enumerate(txt):\n",
    "    filekey[key[i]] = k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert(key):\n",
    "    sentences = \"\"\n",
    "    for line in open(train_path + \"/\" + filekey[key], 'r'):\n",
    "        sentences += (\" \" + line.rstrip())\n",
    "    tokens = sent_tokenize(sentences)\n",
    "    key_file = open(train_path + \"/\" + str(key),'r')\n",
    "    keys = [line.strip() for line in key_file]\n",
    "    key_sent = []\n",
    "    labels = []\n",
    "    for token in tokens:\n",
    "        z = ['O'] * len(token.split())\n",
    "        for k in keys:\n",
    "            if k in token:\n",
    "                \n",
    "                if len(k.split())==1:\n",
    "                    try:\n",
    "                        z[token.lower().split().index(k.lower().split()[0])] = 'B'\n",
    "                    except ValueError:\n",
    "                        continue\n",
    "                elif len(k.split())>1:\n",
    "                    try:\n",
    "                        if token.lower().split().index(k.lower().split()[0]) and token.lower().split().index(k.lower().split()[-1]):\n",
    "                            z[token.lower().split().index(k.lower().split()[0])] = 'B'\n",
    "                            for j in range(1, len(k.split())):\n",
    "                                z[token.lower().split().index(k.lower().split()[j])] = 'I'\n",
    "                    except ValueError:\n",
    "                        continue\n",
    "        for m, n in enumerate(z):\n",
    "            if z[m] == 'I' and z[m-1] == 'O':\n",
    "                z[m] = 'O'\n",
    "\n",
    "        if set(z) != {'O'}:\n",
    "            labels.append(z) \n",
    "            key_sent.append(token)\n",
    "    return key_sent, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences_ = []\n",
    "labels_ = []\n",
    "for key, value in filekey.items():\n",
    "    s, l = convert(key)\n",
    "    sentences_.append(s)\n",
    "    labels_.append(l)\n",
    "sentences = [item for sublist in sentences_ for item in sublist]\n",
    "labels = [item for sublist in labels_ for item in sublist]\n",
    "# print(len(sentences), len(labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm, trange\n",
    "import torch\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from sklearn.model_selection import train_test_split\n",
    "from pytorch_pretrained_bert import BertTokenizer, BertConfig\n",
    "from pytorch_pretrained_bert import BertForTokenClassification, BertAdam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_LEN = 75\n",
    "bs = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tag2idx = {'B': 0, 'I': 1, 'O': 2}\n",
    "tags_vals = ['B', 'I', 'O']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "n_gpu = torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],\n",
    "                          maxlen=MAX_LEN, dtype=\"long\", truncating=\"post\", padding=\"post\")\n",
    "tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in labels],\n",
    "                     maxlen=MAX_LEN, value=tag2idx[\"O\"], padding=\"post\",\n",
    "                     dtype=\"long\", truncating=\"post\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "attention_masks = [[float(i>0) for i in ii] for ii in input_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_inputs, val_inputs, tr_tags, val_tags = train_test_split(input_ids, tags, \n",
    "                                                            random_state=2018, test_size=0.1)\n",
    "tr_masks, val_masks, _, _ = train_test_split(attention_masks, input_ids,\n",
    "                                             random_state=2018, test_size=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_inputs = torch.tensor(tr_inputs)\n",
    "val_inputs = torch.tensor(val_inputs)\n",
    "tr_tags = torch.tensor(tr_tags)\n",
    "val_tags = torch.tensor(val_tags)\n",
    "tr_masks = torch.tensor(tr_masks)\n",
    "val_masks = torch.tensor(val_masks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)\n",
    "\n",
    "valid_data = TensorDataset(val_inputs, val_masks, val_tags)\n",
    "valid_sampler = SequentialSampler(valid_data)\n",
    "valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BertForTokenClassification.from_pretrained(\"bert-base-uncased\", num_labels=len(tag2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "FULL_FINETUNING = True\n",
    "if FULL_FINETUNING:\n",
    "    param_optimizer = list(model.named_parameters())\n",
    "    no_decay = ['bias', 'gamma', 'beta']\n",
    "    optimizer_grouped_parameters = [\n",
    "        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
    "         'weight_decay_rate': 0.01},\n",
    "        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],\n",
    "         'weight_decay_rate': 0.0}\n",
    "    ]\n",
    "else:\n",
    "    param_optimizer = list(model.classifier.named_parameters()) \n",
    "    optimizer_grouped_parameters = [{\"params\": [p for n, p in param_optimizer]}]\n",
    "optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from seqeval.metrics import f1_score\n",
    "\n",
    "def flat_accuracy(preds, labels):\n",
    "    pred_flat = np.argmax(preds, axis=2).flatten()\n",
    "    labels_flat = labels.flatten()\n",
    "    return np.sum(pred_flat == labels_flat) / len(labels_flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:   0%|          | 0/4 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.09159750645318307\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  25%|██▌       | 1/4 [03:56<11:50, 236.93s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation loss: 0.07012036180606594\n",
      "Validation Accuracy: 0.9785305212620027\n",
      "F1-Score: 0.2601214574898785\n",
      "Train loss: 0.06027003754823169\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  50%|█████     | 2/4 [07:55<07:54, 237.36s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation loss: 0.05512299305862851\n",
      "Validation Accuracy: 0.9826303155006857\n",
      "F1-Score: 0.4358124111795358\n",
      "Train loss: 0.04443128131453163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  75%|███████▌  | 3/4 [11:53<03:57, 237.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation loss: 0.04628589225036127\n",
      "Validation Accuracy: 0.9847713763145862\n",
      "F1-Score: 0.5485906604964241\n",
      "Train loss: 0.03500022632185339\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch: 100%|██████████| 4/4 [15:51<00:00, 237.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation loss: 0.04753524796278388\n",
      "Validation Accuracy: 0.9848725422953818\n",
      "F1-Score: 0.5591739475774424\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "epochs = 4\n",
    "max_grad_norm = 1.0\n",
    "\n",
    "for _ in trange(epochs, desc=\"Epoch\"):\n",
    "    # TRAIN loop\n",
    "    model.train()\n",
    "    tr_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        # add batch to gpu\n",
    "        batch = tuple(t.to(device) for t in batch)\n",
    "        b_input_ids, b_input_mask, b_labels = batch\n",
    "        # forward pass\n",
    "        loss = model(b_input_ids, token_type_ids=None,\n",
    "                     attention_mask=b_input_mask, labels=b_labels)\n",
    "        # backward pass\n",
    "        loss.backward()\n",
    "        # track train loss\n",
    "        tr_loss += loss.item()\n",
    "        nb_tr_examples += b_input_ids.size(0)\n",
    "        nb_tr_steps += 1\n",
    "        # gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        # update parameters\n",
    "        optimizer.step()\n",
    "        model.zero_grad()\n",
    "    # print train loss per epoch\n",
    "    print(\"Train loss: {}\".format(tr_loss/nb_tr_steps))\n",
    "    # VALIDATION on validation set\n",
    "    model.eval()\n",
    "    eval_loss, eval_accuracy = 0, 0\n",
    "    nb_eval_steps, nb_eval_examples = 0, 0\n",
    "    predictions , true_labels = [], []\n",
    "    for batch in valid_dataloader:\n",
    "        batch = tuple(t.to(device) for t in batch)\n",
    "        b_input_ids, b_input_mask, b_labels = batch\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            tmp_eval_loss = model(b_input_ids, token_type_ids=None,\n",
    "                                  attention_mask=b_input_mask, labels=b_labels)\n",
    "            logits = model(b_input_ids, token_type_ids=None,\n",
    "                           attention_mask=b_input_mask)\n",
    "        logits = logits.detach().cpu().numpy()\n",
    "        label_ids = b_labels.to('cpu').numpy()\n",
    "        predictions.extend([list(p) for p in np.argmax(logits, axis=2)])\n",
    "        true_labels.append(label_ids)\n",
    "        \n",
    "        tmp_eval_accuracy = flat_accuracy(logits, label_ids)\n",
    "        \n",
    "        eval_loss += tmp_eval_loss.mean().item()\n",
    "        eval_accuracy += tmp_eval_accuracy\n",
    "        \n",
    "        nb_eval_examples += b_input_ids.size(0)\n",
    "        nb_eval_steps += 1\n",
    "    eval_loss = eval_loss/nb_eval_steps\n",
    "    print(\"Validation loss: {}\".format(eval_loss))\n",
    "    print(\"Validation Accuracy: {}\".format(eval_accuracy/nb_eval_steps))\n",
    "    pred_tags = [tags_vals[p_i] for p in predictions for p_i in p]\n",
    "    valid_tags = [tags_vals[l_ii] for l in true_labels for l_i in l for l_ii in l_i]\n",
    "    print(\"F1-Score: {}\".format(f1_score(pred_tags, valid_tags)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation loss: 0.04753524796278388\n",
      "Validation Accuracy: 0.9848725422953818\n",
      "Validation F1-Score: 0.5591739475774424\n"
     ]
    }
   ],
   "source": [
    "model.eval()\n",
    "predictions = []\n",
    "true_labels = []\n",
    "eval_loss, eval_accuracy = 0, 0\n",
    "nb_eval_steps, nb_eval_examples = 0, 0\n",
    "for batch in valid_dataloader:\n",
    "    batch = tuple(t.to(device) for t in batch)\n",
    "    b_input_ids, b_input_mask, b_labels = batch\n",
    "\n",
    "    with torch.no_grad():\n",
    "        tmp_eval_loss = model(b_input_ids, token_type_ids=None,\n",
    "                              attention_mask=b_input_mask, labels=b_labels)\n",
    "        logits = model(b_input_ids, token_type_ids=None,\n",
    "                       attention_mask=b_input_mask)\n",
    "        \n",
    "    logits = logits.detach().cpu().numpy()\n",
    "    predictions.extend([list(p) for p in np.argmax(logits, axis=2)])\n",
    "\n",
    "    label_ids = b_labels.to('cpu').numpy()\n",
    "    true_labels.append(label_ids)\n",
    "    tmp_eval_accuracy = flat_accuracy(logits, label_ids)\n",
    "\n",
    "    eval_loss += tmp_eval_loss.mean().item()\n",
    "    eval_accuracy += tmp_eval_accuracy\n",
    "\n",
    "    nb_eval_examples += b_input_ids.size(0)\n",
    "    nb_eval_steps += 1\n",
    "\n",
    "pred_tags = [[tags_vals[p_i] for p_i in p] for p in predictions]\n",
    "valid_tags = [[tags_vals[l_ii] for l_ii in l_i] for l in true_labels for l_i in l ]\n",
    "print(\"Validation loss: {}\".format(eval_loss/nb_eval_steps))\n",
    "print(\"Validation Accuracy: {}\".format(eval_accuracy/nb_eval_steps))\n",
    "print(\"Validation F1-Score: {}\".format(f1_score(pred_tags, valid_tags)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get keywords from sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def keywordextract(sentence):\n",
    "    text = sentence\n",
    "    tkns = tokenizer.tokenize(text)\n",
    "    indexed_tokens = tokenizer.convert_tokens_to_ids(tkns)\n",
    "    segments_ids = [0] * len(tkns)\n",
    "    tokens_tensor = torch.tensor([indexed_tokens]).to(device)\n",
    "    segments_tensors = torch.tensor([segments_ids]).to(device)\n",
    "    model.eval()\n",
    "    prediction = []\n",
    "    logit = model(tokens_tensor, token_type_ids=None,\n",
    "                                  attention_mask=segments_tensors)\n",
    "    logit = logit.detach().cpu().numpy()\n",
    "    prediction.extend([list(p) for p in np.argmax(logit, axis=2)])\n",
    "    for k, j in enumerate(prediction[0]):\n",
    "        if j==1 or j==0:\n",
    "            print(tokenizer.convert_ids_to_tokens(tokens_tensor[0].to('cpu').numpy())[k], j)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"The solution is based upon an abstract representation of the mobile object system.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mobile 0\n"
     ]
    }
   ],
   "source": [
    "keywordextract(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_pytorch_p36)",
   "language": "python",
   "name": "conda_pytorch_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
